#include <cuda_runtime.h>

#include <cstdio>
#include <iostream>

// Ceiling integer division: ceil(a / b) for positive integers.
// NOTE(review): this shadows math.h ceil() if that is ever called in this file.
#define ceil(a, b) (((a)+(b)-1)/(b))

// Report (but do not abort on) a CUDA runtime error returned by `func`,
// tagged with the file and line of the call site.
// The do { ... } while (0) wrapper makes the macro a single statement, so it
// expands safely inside if/else without a dangling-semicolon problem.
// (The original expanded to `{ do { ... } }` with no `while (0)`, which is a
// syntax error at every use site.)
#define CHECK_ERROR(func)                                                    \
    do {                                                                     \
        cudaError_t e = (func);                                              \
        if (cudaSuccess != e) {                                              \
            printf("%s %d CUDA: %s\n", __FILE__, __LINE__,                   \
                   cudaGetErrorString(e));                                   \
        }                                                                    \
    } while (0)

// Single-head flash-attention forward pass: O = softmax(Q K^T) V, computed
// in one pass with the numerically stable "online softmax" recurrence.
//
// Launch geometry: grid = (1, ceil(N / BLOCK_DIM)), block = (BLOCK_DIM, BLOCK_DIM).
// Each block owns one tile of BLOCK_DIM query rows (selected by blockIdx.y)
// and streams over every key/value tile, so no cross-block communication is
// needed (the original design had racy read-modify-writes of max/sum across
// blocks and never wrote O at all).
//
// Parameters:
//   Q, K, V : [N x d] row-major device matrices.
//   N, d    : sequence length and head dimension.
//   max,sum : [N] outputs — final per-row softmax max and normalizer.
//   O       : [N x d] output matrix.
//
// Preconditions (documented, not checked): d <= D, and N and d are multiples
// of BLOCK_DIM — TODO(review): confirm ragged tails are never needed here.
// Shared memory: 4*BLOCK_DIM*D + BLOCK_DIM*BLOCK_DIM + 3*BLOCK_DIM floats
// (about 34 KB for BLOCK_DIM=16, D=128 — within the 48 KB static limit).
// NOTE: scores are NOT scaled by 1/sqrt(d), matching the original code.
template<int BLOCK_DIM, int D>
__global__ void flash_attention(float* Q, float* K, float* V,
                                unsigned int N, unsigned int d,
                                float* max, float* sum, float* O){
    __shared__ float Q_sm[BLOCK_DIM][D];   // this block's query tile
    __shared__ float K_sm[BLOCK_DIM][D];   // current key tile
    __shared__ float V_sm[BLOCK_DIM][D];   // current value tile
    __shared__ float O_sm[BLOCK_DIM][D];   // un-normalized output accumulator
    __shared__ float S[BLOCK_DIM][BLOCK_DIM];  // scores, then exp-weights
    __shared__ float row_max[BLOCK_DIM];   // running row maximum m_i
    __shared__ float row_sum[BLOCK_DIM];   // running row normalizer l_i
    __shared__ float row_scale[BLOCK_DIM]; // exp(m_old - m_new) rescale factor

    const unsigned int tx = threadIdx.x;
    const unsigned int ty = threadIdx.y;
    const unsigned int q_row = blockIdx.y * BLOCK_DIM;        // first query row
    const int chunks = (d + BLOCK_DIM - 1) / BLOCK_DIM;       // feature chunks

    if (ty == 0) {
        row_max[tx] = -__FLT_MAX__;
        row_sum[tx] = 0.0f;
    }

    // Load the Q tile once and zero the output accumulator.
    for (int c = 0; c < chunks; c++) {
        unsigned int col = c * BLOCK_DIM + tx;
        if (col < d) {
            Q_sm[ty][col] = Q[(q_row + ty) * d + col];
            O_sm[ty][col] = 0.0f;
        }
    }
    __syncthreads();

    // Stream over all key/value tiles.
    const int kv_tiles = (N + BLOCK_DIM - 1) / BLOCK_DIM;
    for (int t = 0; t < kv_tiles; t++) {
        const unsigned int kv_row = t * BLOCK_DIM;
        for (int c = 0; c < chunks; c++) {
            unsigned int col = c * BLOCK_DIM + tx;
            if (col < d) {
                K_sm[ty][col] = K[(kv_row + ty) * d + col];
                V_sm[ty][col] = V[(kv_row + ty) * d + col];
            }
        }
        __syncthreads();  // tiles fully loaded before any thread reads them

        // S[y][x] = <Q row y, K row x> for this tile pair (fresh each tile).
        float acc = 0.0f;
        for (unsigned int k = 0; k < d; k++)
            acc += Q_sm[ty][k] * K_sm[tx][k];
        S[ty][tx] = acc;
        __syncthreads();

        // Online-softmax update, one thread per query row (row index = tx).
        if (ty == 0) {
            float m_new = row_max[tx];
            for (int k = 0; k < BLOCK_DIM; k++)
                if (S[tx][k] > m_new) m_new = S[tx][k];
            // Rescale factor for everything accumulated under the old max.
            float scale = expf(row_max[tx] - m_new);
            float p_sum = 0.0f;
            for (int k = 0; k < BLOCK_DIM; k++) {
                float p = expf(S[tx][k] - m_new);
                S[tx][k] = p;              // reuse S to hold the exp-weights
                p_sum += p;
            }
            row_max[tx] = m_new;
            row_sum[tx] = row_sum[tx] * scale + p_sum;
            row_scale[tx] = scale;
        }
        __syncthreads();

        // O_sm = O_sm * scale + P * V for this tile.
        for (int c = 0; c < chunks; c++) {
            unsigned int col = c * BLOCK_DIM + tx;
            if (col < d) {
                float acc_o = 0.0f;
                for (int k = 0; k < BLOCK_DIM; k++)
                    acc_o += S[ty][k] * V_sm[k][col];
                O_sm[ty][col] = O_sm[ty][col] * row_scale[ty] + acc_o;
            }
        }
        __syncthreads();  // done with S/V_sm before the next tile overwrites them
    }

    // Normalize and write results; also publish the per-row softmax stats.
    for (int c = 0; c < chunks; c++) {
        unsigned int col = c * BLOCK_DIM + tx;
        if (col < d)
            O[(q_row + ty) * d + col] = O_sm[ty][col] / row_sum[ty];
    }
    if (ty == 0) {
        max[q_row + tx] = row_max[tx];
        sum[q_row + tx] = row_sum[tx];
    }
}

// Host driver: builds all-ones Q/K/V, runs the kernel, and prints O[0].
// With identical rows, softmax is uniform and V is all ones, so every
// element of O should come back as exactly 1.
int main() {
    constexpr unsigned int N = 2048;        // sequence length
    constexpr unsigned int d = 128;         // head dimension
    constexpr int BLOCK_DIM = 16;           // compile-time: used as template arg
    size_t byte_size = (size_t)N * d * sizeof(float);
    size_t stat_bytes = (size_t)N * sizeof(float);  // per-row max/sum vectors

    float *Q_host, *K_host, *V_host, *O_host;
    float *Q_dev, *K_dev, *V_dev, *O_dev, *max_dev, *sum_dev;

    Q_host = (float*)malloc(byte_size);
    K_host = (float*)malloc(byte_size);
    V_host = (float*)malloc(byte_size);
    O_host = (float*)malloc(byte_size);
    if (!Q_host || !K_host || !V_host || !O_host) {
        fprintf(stderr, "host allocation failed\n");
        return 1;
    }

    for (size_t i = 0; i < (size_t)N * d; i++) {
        Q_host[i] = 1;
        K_host[i] = 1;
        V_host[i] = 1;
    }

    CHECK_ERROR(cudaMalloc((void**)&Q_dev, byte_size));
    CHECK_ERROR(cudaMalloc((void**)&K_dev, byte_size));
    CHECK_ERROR(cudaMalloc((void**)&V_dev, byte_size));
    CHECK_ERROR(cudaMalloc((void**)&O_dev, byte_size));
    // The kernel also takes per-row softmax statistics buffers; the original
    // never allocated these and launched with too few arguments.
    CHECK_ERROR(cudaMalloc((void**)&max_dev, stat_bytes));
    CHECK_ERROR(cudaMalloc((void**)&sum_dev, stat_bytes));
    CHECK_ERROR(cudaMemset(max_dev, 0, stat_bytes));
    CHECK_ERROR(cudaMemset(sum_dev, 0, stat_bytes));

    CHECK_ERROR(cudaMemcpy(Q_dev, Q_host, byte_size, cudaMemcpyHostToDevice));
    CHECK_ERROR(cudaMemcpy(K_dev, K_host, byte_size, cudaMemcpyHostToDevice));
    CHECK_ERROR(cudaMemcpy(V_dev, V_host, byte_size, cudaMemcpyHostToDevice));

    // One block column; one block row per tile of BLOCK_DIM query rows.
    dim3 grid(1, ceil(N, BLOCK_DIM));
    dim3 block(BLOCK_DIM, BLOCK_DIM);
    // Template arguments are required — the kernel has no defaults. D = d.
    flash_attention<BLOCK_DIM, (int)d><<<grid, block>>>(
        Q_dev, K_dev, V_dev, N, d, max_dev, sum_dev, O_dev);
    CHECK_ERROR(cudaGetLastError());  // catch bad-launch-config errors

    // Blocking copy-back doubles as synchronization with the kernel.
    CHECK_ERROR(cudaMemcpy(O_host, O_dev, byte_size, cudaMemcpyDeviceToHost));

    // Print BEFORE freeing — the original read O_host[0] after free().
    std::cout << O_host[0] << "\n";

    free(Q_host);
    free(K_host);
    free(V_host);
    free(O_host);
    cudaFree(Q_dev);
    cudaFree(K_dev);
    cudaFree(V_dev);
    cudaFree(O_dev);
    cudaFree(max_dev);
    cudaFree(sum_dev);

    return 0;
}